feat(mlx): pt.random support with mlx backend #1979
williambdean wants to merge 8 commits into pymc-devs:v3
Conversation
Force-pushed from b620bb1 to d602268
ricardoV94 left a comment:
The rng outputs/updates are missing (so that consecutive calls get an updated rng).
There should be tests in the numba/jax backends you can use as a template. JAX is going to be more similar.
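A minimal sketch of the functional pattern being asked for, using numpy as a stand-in for MLX (the names and the key arithmetic are hypothetical; JAX would use `jax.random.split` here): the sampler returns an updated key together with the draws, so two consecutive calls do not repeat samples.

```python
import numpy as np

def sample_fn(rng_key, size):
    # Hypothetical sketch: thread the RNG state functionally by returning
    # an updated key alongside the draws (mirrors the JAX backend's idea;
    # the `+ 1` below is a stand-in for jax.random.split).
    next_key = rng_key + 1
    draws = np.random.default_rng(rng_key).normal(size=size)
    return next_key, draws

key = 0
key, a = sample_fn(key, 3)
key, b = sample_fn(key, 3)  # different key, so different draws
```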
```python
thunk_inputs = []
for n in self.fgraph.inputs:
    sinput = storage_map[n]
    if isinstance(sinput[0], Generator):
```
You need to do the same dance the JAX linker does with shared Generator variables.
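A hedged sketch of that "dance" (hypothetical helper name; the real JAX linker typifies shared random state, this just illustrates the shape of the conversion): when the input storage holds a shared numpy Generator, replace it with an integer key the backend's functional RNG API can consume.

```python
import numpy as np

def prepare_thunk_input(sinput):
    # Hypothetical sketch: swap a shared numpy Generator in the storage
    # cell for an integer seed/key, analogous to what the JAX linker
    # does for shared RNG inputs.
    if isinstance(sinput[0], np.random.Generator):
        sinput[0] = int(sinput[0].integers(0, 2**32))
    return sinput

storage = [np.random.default_rng(42)]
prepare_thunk_input(storage)
```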
#2010 caused conflicts for this PR. You will need to rebase.
Force-pushed from e6f7371 to 0b4fb85
pytensor/link/mlx/dispatch/random.py (outdated)
```python
def sample_fn(rng_key, size, dtype, p):
    p = mx.array(p)
    if size is None:
        shape = p.shape
```
Do you always need the shape? You didn't need it in the categorical case. I would assume you only need it when one of the parameters doesn't go into the random function. If so, that would take a lot of boilerplate away from your dispatches.
My comment wasn't about Bernoulli specifically; I would expect that you don't need to define the shape explicitly (when the user didn't do it themselves) most of the time.
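One way to read this suggestion, sketched with a hypothetical helper (numpy used for the broadcasting logic): when no explicit size was given, let the draw shape follow from broadcasting the distribution parameters instead of hand-computing it in every dispatch.

```python
import numpy as np

def implicit_shape(size, *params):
    # Hypothetical helper: derive the draw shape by broadcasting the
    # parameters when the user gave no explicit size, removing the
    # per-dispatch shape boilerplate.
    if size is not None:
        return tuple(size)
    return np.broadcast_shapes(*(np.shape(p) for p in params))

shape = implicit_shape(None, np.zeros((2, 3)), 1.0)  # broadcasts to (2, 3)
```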
```python
    return sample_fn


@mlx_sample_fn.register(ptr.MvNormalRV)
```
MvNormal supports different decomposition strategies; if mx.random.multivariate_normal doesn't support them, you may want to implement it like the numba dispatch / Op.perform, which is more low level. Or, if that is unfeasible, issue a warning that the requested method isn't respected and will fall back to svd (if it wasn't svd to begin with).
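A hedged illustration of the low-level route (numpy stand-in, hypothetical function name, loosely modelled on the idea behind the numba dispatch rather than its actual code): sample standard normals and transform them with a factor of the covariance chosen according to the requested decomposition method.

```python
import numpy as np

def mvnormal_draw(rng, mean, cov, method="cholesky"):
    # Sketch: respect the requested decomposition instead of always
    # delegating to a library multivariate_normal call.
    if method == "cholesky":
        factor = np.linalg.cholesky(cov)
    elif method == "svd":
        u, s, _ = np.linalg.svd(cov)
        factor = u * np.sqrt(s)  # factor @ factor.T == cov
    else:
        raise NotImplementedError(f"decomposition {method!r} not supported")
    return mean + factor @ rng.standard_normal(mean.shape)

rng = np.random.default_rng(0)
draw = mvnormal_draw(rng, np.zeros(2), np.eye(2), method="svd")
```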
pytensor/link/mlx/dispatch/random.py (outdated)
```python
if batch_ndim:
    raise NotImplementedError(
        "MLX random.permutation does not support batch dimensions."
    )
```
Raise at dispatch time already, so the error surfaces when the function is compiled rather than on the first call.
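A sketch of what that refactor looks like (hypothetical names and simplified signature): the unsupported-case check moves out of the returned sampler and into the dispatch function itself, so compilation fails fast.

```python
def mlx_sample_fn_permutation(batch_ndim):
    # Hypothetical sketch: validate at dispatch time instead of deferring
    # the NotImplementedError to the first call of sample_fn.
    if batch_ndim:
        raise NotImplementedError(
            "MLX random.permutation does not support batch dimensions."
        )

    def sample_fn(rng_key, x):
        ...  # would call mx.random.permutation(x, key=rng_key) here

    return sample_fn

fn = mlx_sample_fn_permutation(0)  # dispatch succeeds for the supported case
```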
…ethods, permutation error at dispatch time
Force-pushed from 0af680e to 5d450de
Description
Basic support for mlx random generation. MLX's random module has limited coverage; the Gamma distribution is missing. Additional distributions could be supported with basic transformations, e.g. pt.abs(pt.random.normal(...)) ~ Half-Normal.

MLX reference: https://ml-explore.github.io/mlx/build/html/python/random.html
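The transformation idea can be illustrated with numpy as a stand-in for MLX (this only demonstrates the math, not the pytensor dispatch): the absolute value of a standard normal is a Half-Normal, whose mean is sqrt(2/pi) ≈ 0.7979.

```python
import numpy as np

# Distributions MLX lacks can be derived from ones it provides:
# |Normal(0, 1)| is distributed as Half-Normal(1).
rng = np.random.default_rng(0)
draws = np.abs(rng.normal(size=100_000))

mean = draws.mean()  # should be close to sqrt(2 / pi)
```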
Related Issue
Checklist
Type of change